#!/usr/bin/env python
# Emitted before the remaining imports run; presumably a progress hint
# because loading the libraries below takes a noticeable time (TODO confirm).
print "importing libraries"
import rosbag
import cv_bridge

import sm
from sm import PlotCollection
import aslam_cv as acv
import aslam_cameras_april as acv_april
import aslam_cv_backend as acvb
import kalibr_common as kc
import kalibr_camera_calibration as kcc

import cv2
import os
import numpy as np
import pylab as pl
import argparse
import sys
import random
import signal

# Print small floating point values in fixed-point instead of scientific
# notation (applies to every numpy array printed by this script).
np.set_printoptions(suppress=True)

def initBagDataset(bagfile, topic, from_to):
    """Open the image dataset of one topic in a bag file and print a summary.

    Args:
        bagfile: path of the bag file, passed to kc.BagImageDatasetReader.
        topic: image topic to read.
        from_to: forwarded as ``bag_from_to`` to the reader (the value of the
            --bag-from-to option, or None when it was not given).

    Returns:
        The kc.BagImageDatasetReader created for (bagfile, topic).
    """
    #announce the source before opening it, so the info is shown even if opening fails
    print("\tDataset:          {0}".format(bagfile))
    print("\tTopic:            {0}".format(topic))

    dataset = kc.BagImageDatasetReader(bagfile, topic, bag_from_to=from_to)

    print("\tNumber of images: {0}".format(dataset.numImages()))
    return dataset

#available camera models: maps the name accepted by --models to the
#aslam_cv_backend model class handed to kcc.CameraGeometry
cameraModels = {
    'pinhole-radtan': acvb.DistortedPinhole,
    'pinhole-equi': acvb.EquidistantPinhole,
    'pinhole-fov': acvb.FovPinhole,
    'omni-none': acvb.Omni,
    'omni-radtan': acvb.DistortedOmni,
    'eucm-none': acvb.ExtendedUnified,
    'ds-none': acvb.DoubleSphere,
}

def signal_exit(signal, frame):
    """SIGINT handler: log a warning and terminate with exit status 2.

    Both arguments are supplied by the signal machinery and are not used.
    NOTE: the ``signal`` parameter shadows the ``signal`` module inside this
    function; the name is kept to leave the handler's signature unchanged.
    """
    shutdown_message = "Shutdown requested! (CTRL+C)"
    sm.logWarn(shutdown_message)
    #equivalent to sys.exit(2)
    raise SystemExit(2)

def parseArgs():
    """Parse and validate the command line arguments of the camera calibrator.

    The process exits with status 2 (after printing the help and/or an error)
    when the script is called without arguments, when argparse rejects the
    command line, or when one of the consistency checks below fails.

    Returns:
        The argparse namespace holding all options. ``dontShowReport`` is
        forced to True when --verbose or --show-extraction is given.
    """
    class KalibrArgParser(argparse.ArgumentParser):
        """ArgumentParser that prints the full help on errors and uses a custom help layout."""
        def error(self, message):
            #show the complete help first, then the actual error, then exit with status 2
            self.print_help()
            sm.logError('%s' % message)
            sys.exit(2)
        def format_help(self):
            #help layout: description, usage, one section per argument group, epilog
            formatter = self._get_formatter()
            formatter.add_text(self.description)
            formatter.add_usage(self.usage, self._actions,
                                self._mutually_exclusive_groups)
            for action_group in self._action_groups:
                formatter.start_section(action_group.title)
                formatter.add_text(action_group.description)
                formatter.add_arguments(action_group._group_actions)
                formatter.end_section()
            formatter.add_text(self.epilog)
            return formatter.format_help()

    usage = """
    Example usage to calibrate a camera system with two cameras using an aprilgrid. 
    
    cam0: omnidirection model with radial-tangential distortion
    cam1: pinhole model with equidistant distortion
    
    %(prog)s --models omni-radtan pinhole-equi --target aprilgrid.yaml \\
              --bag MYROSBAG.bag --topics /cam0/image_raw /cam1/image_raw
    
    example aprilgrid.yaml:
        target_type: 'aprilgrid'
        tagCols: 6
        tagRows: 6
        tagSize: 0.088  #m
        tagSpacing: 0.3 #percent of tagSize"""

    parser = KalibrArgParser(description='Calibrate the intrinsics and extrinsics of a camera system with non-shared overlapping field of view.', usage=usage)
    parser.add_argument('--models', nargs='+', dest='models', help='The camera model {0} to estimate'.format(cameraModels.keys()), required=True)

    groupSource = parser.add_argument_group('Data source')
    #the bag file is mandatory: without it the dataset cannot be opened and the
    #output file names (derived from the bag file name) cannot be built
    groupSource.add_argument('--bag', dest='bagfile', help='The bag file with the data', required=True)
    groupSource.add_argument('--topics', nargs='+', dest='topics', help='The list of image topics', required=True)
    groupSource.add_argument('--bag-from-to', metavar='bag_from_to', type=float, nargs=2, help='Use the bag data starting from up to this time [s]')

    groupTarget = parser.add_argument_group('Calibration target configuration')
    groupTarget.add_argument('--target', dest='targetYaml', help='Calibration target configuration as yaml file', required=True)

    groupSync = parser.add_argument_group('Image synchronization')
    groupSync.add_argument('--approx-sync', dest='max_delta_approxsync', type=float, default=0.02, help='Time tolerance for approximate image synchronization [s] (default: %(default)s)')

    groupCalibrator = parser.add_argument_group('Calibrator settings')
    groupCalibrator.add_argument('--qr-tol', type=float, default=0.02, dest='qrTol', help='The tolerance on the factors of the QR decomposition (default: %(default)s)')
    groupCalibrator.add_argument('--mi-tol', type=float, default=0.2, dest='miTol', help='The tolerance on the mutual information for adding an image. Higher means fewer images will be added. Use -1 to force all images. (default: %(default)s)')
    groupCalibrator.add_argument('--no-shuffle', action='store_true', dest='noShuffle', help='Do not shuffle the dataset processing order')

    outlierSettings = parser.add_argument_group('Outlier filtering options')
    outlierSettings.add_argument('--no-outliers-removal', action='store_false', default=True, dest='removeOutliers', help='Disable corner outlier filtering')
    outlierSettings.add_argument('--no-final-filtering', action='store_false', default=True, dest='allowEndFiltering', help='Disable filtering after all views have been processed.')
    outlierSettings.add_argument('--min-views-outlier', type=int, default=20, dest='minViewOutlier', help='Number of raw views to initialize statistics (default: %(default)s)')
    outlierSettings.add_argument('--use-blakezisserman', action='store_true', dest='doBlakeZisserman', help='Enable the Blake-Zisserman m-estimator')
    outlierSettings.add_argument('--plot-outliers', action='store_true', dest='doPlotOutliers', help='Plot the detect outliers during extraction (this could be slow)')

    outputSettings = parser.add_argument_group('Output options')
    outputSettings.add_argument('--verbose', action='store_true', dest='verbose', help='Enable (really) verbose output (disables plots)')
    outputSettings.add_argument('--show-extraction', action='store_true', dest='showextraction', help='Show the calibration target extraction. (disables plots)')
    outputSettings.add_argument('--plot', action='store_true', dest='plot', help='Plot during calibration (this could be slow).')
    outputSettings.add_argument('--dont-show-report', action='store_true', dest='dontShowReport', help='Do not show the report on screen after calibration.')

    #print help if no argument is specified
    if len(sys.argv)==1:
        parser.print_help()
        sys.exit(2)

    #parse the argument list
    #argparse terminates through SystemExit (errors as well as --help); every such
    #exit is mapped to status 2. Any other exception is a real bug and propagates.
    try:
        parsed = parser.parse_args()
    except SystemExit:
        sys.exit(2)

    #some checks
    if len(parsed.topics) != len(parsed.models):
        sm.logError("Please specify exactly one camera model (--models) for each topic (--topics).")
        sys.exit(2)

    if parsed.minViewOutlier<1:
        sm.logError("Please specify a positive integer (--min-views-outlier).")
        sys.exit(2)

    #the gtk plot widget cannot be used while opencv windows are open
    #--> do not show the report on screen in these special situations
    if parsed.showextraction or parsed.verbose:
        parsed.dontShowReport = True

    return parsed


def main():
    parsed = parseArgs()
    
    #logging modes
    if parsed.verbose:
        sm.setLoggingLevel(sm.LoggingLevel.Debug)
    else:
        sm.setLoggingLevel(sm.LoggingLevel.Info)

    #register signal handler
    signal.signal(signal.SIGINT, signal_exit)

    targetConfig = kc.CalibrationTargetParameters(parsed.targetYaml)

    #create camera objects, initialize the intrinsics and extract targets
    cameraList = list()
    numCams = len(parsed.topics)

    obsdb = kcc.ObservationDatabase(parsed.max_delta_approxsync)
        
    for cam_id in range(0, numCams):
        topic = parsed.topics[cam_id]
        modelName = parsed.models[cam_id]
        print "Initializing cam{0}:".format(cam_id)
        print "\tCamera model:\t  {0}".format(modelName)

        if modelName in cameraModels:
            #open dataset 
            dataset = initBagDataset(parsed.bagfile, topic, parsed.bag_from_to)
        
            #create camera
            cameraModel = cameraModels[modelName]
            cam = kcc.CameraGeometry(cameraModel, targetConfig, dataset, verbose=(parsed.verbose or parsed.showextraction))
                
            #extract the targets
            multithreading = not (parsed.verbose or parsed.showextraction)
            observations = kc.extractCornersFromDataset(cam.dataset, cam.ctarget.detector, 
                                                        multithreading=multithreading, clearImages=False,
                                                        noTransformation=True)
            
            #populate the database
            for obs in observations:
                obsdb.addObservation(cam_id, obs)

            #initialize the intrinsics
            if not cam.initGeometryFromObservations(observations):
                raise RuntimeError("Could not initialize the intrinsics for camera with topic: {0}. Try to use --verbose and check whether the calibration target extraction is successful.".format(topic))
            
            print "\tProjection initialized to: %s" % cam.geometry.projection().getParameters().flatten()
            print "\tDistortion initialized to: %s" % cam.geometry.projection().distortion().getParameters().flatten()
            
            cameraList.append(cam)
        else:
            raise RuntimeError( "Unknown camera model: {0}. Try {1}.".format(modelName, cameraModels.keys()) )
        
    if parsed.verbose:
        obsdb.printTable()
    
    #initialize the calibration graph
    graph = kcc.MulticamCalibrationGraph(obsdb)
    
    if not graph.isGraphConnected():
        obsdb.printTable()
        print "Cameras are not connected through mutual observations, please check the dataset. Maybe adjust the approx. sync. tolerance."
        graph.plotGraph()
        sys.exit(-1)
       
    #loop to restart the optimization
    restartAttempts=3
    initOutlierRejection=True
    removedOutlierCorners=list() 
    while True:
        try:
            #compute initial guesses for the baselines, intrinsics
            print "initializing initial guesses"
            if len(cameraList)>1:
                baseline_guesses = graph.getInitialGuesses(cameraList)
            else:
                baseline_guesses=[]
                
            if parsed.verbose and len(cameraList)>1:
                graph.plotGraph()

            for baseline_idx, baseline in enumerate(baseline_guesses):
                print "initialized baseline between cam{0} and cam{1} to:".format(baseline_idx, baseline_idx+1)
                print baseline.T()
                
            for cam_idx, cam in enumerate(cameraList):
                print "initialized cam{0} to:".format(cam_idx)
                print "\t projection cam{0}: {1}".format(cam_idx, cam.geometry.projection().getParameters().flatten())
                print "\t distortion cam{0}: {1}".format(cam_idx, cam.geometry.projection().distortion().getParameters().flatten())
                

            print "initializing calibrator"
            calibrator = kcc.CameraCalibration(cameraList, baseline_guesses, verbose=parsed.verbose, useBlakeZissermanMest=parsed.doBlakeZisserman)
            options = calibrator.estimator.getOptions()
            options.infoGainDelta = parsed.miTol
            options.checkValidity = True
            options.verbose = parsed.verbose
            linearSolverOptions = calibrator.estimator.getLinearSolverOptions()
            linearSolverOptions.columnScaling = True
            linearSolverOptions.verbose = parsed.verbose
            linearSolverOptions.epsSVD = 1e-6
            #linearSolverOptions.svdTol = 0.0 #TODO
            #linearSolverOptions.qrTol = 0.0
            
            optimizerOptions = calibrator.estimator.getOptimizerOptions()
            optimizerOptions.maxIterations = 50
            optimizerOptions.verbose = parsed.verbose
            verbose = parsed.verbose
        
            doPlot = parsed.plot
            if doPlot:
                print "Plotting during calibration. Things may be very slow (but you might learn something)."

            #shuffle the views
            timestamps = obsdb.getAllViewTimestamps()
            if not parsed.noShuffle:
                random.shuffle(timestamps)

            #process all target views
            print "starting calibration..."
            numViews = len(timestamps)
            progress = sm.Progress2(numViews); progress.sample()
            for view_id, timestamp in enumerate(timestamps):
                
                #add new batch problem
                obs_tuple = obsdb.getAllObsAtTimestamp(timestamp)    
                est_baselines = list()
                for bidx, baseline in enumerate(calibrator.baselines):
                    est_baselines.append( sm.Transformation(baseline.T()) )
                T_tc_guess = graph.getTargetPoseGuess(timestamp, cameraList, est_baselines)
                               
                success = calibrator.addTargetView(obs_tuple, T_tc_guess)
                
                #display process
                if (verbose or (view_id % 25) == 0) and calibrator.estimator.getNumBatches()>0 and view_id>1:
                    print
                    print "------------------------------------------------------------------"
                    print
                    print "Processed {0} of {1} views with {2} views used".format(view_id+1, numViews, calibrator.estimator.getNumBatches())
                    print
                    kcc.printParameters(calibrator)
                    print
                    print "------------------------------------------------------------------"
                    
                #calibration progress
                progress.sample()
                
                #plot added views
                if success and doPlot:
                    recent_view = calibrator.views[-1]
                    cams_in_view = [obs_tuple[0] for obs_tuple in recent_view.rig_observations]
                    plotter = PlotCollection.PlotCollection("Added view (stamp: {0})".format(timestamp))
                    for cam_id in cams_in_view:
                        fig=pl.figure(view_id*5000+cam_id)
                        kcc.plotAllReprojectionErrors(calibrator, cam_id, fno=fig.number, noShow=True)                        
                        plotter.add_figure("cam{0}".format(cam_id), fig)
                    plotter.show()   
                
                # Look for outliers
                runEndFiltering = view_id==(len(timestamps)-1) and parsed.allowEndFiltering # run another filtering step at the end (over all batches)
                numActiveBatches = calibrator.estimator.getNumBatches()
                if ((success and numActiveBatches>parsed.minViewOutlier*numCams) or (runEndFiltering and numActiveBatches>parsed.minViewOutlier*numCams)) and parsed.removeOutliers: 
                    #create the list of the batches to check               
                    if initOutlierRejection:
                        #check all views after the min. number of batches has been reached
                        batches_to_check=range(0, calibrator.estimator.getNumBatches())
                        print;print
                        print "Filtering outliers in all batches..."
                        initOutlierRejection=False
                        progress_filter = sm.Progress2(len(batches_to_check)); progress_filter.sample()
                    elif runEndFiltering:
                        #check all batches again after all views have been processed
                        print;print
                        print "All views have been processed.\n\nStarting final outlier filtering..."
                        batches_to_check=range(0, calibrator.estimator.getNumBatches())
                        progress_filter = sm.Progress2(len(batches_to_check)); progress_filter.sample()
                    else:
                        #only check most recent view
                        batches_to_check = [ calibrator.estimator.getNumBatches()-1 ]
                    
                    #now check all the specified batches
                    batches_to_check.sort()
                    batches_to_check.reverse()
                    for batch_id in batches_to_check:
                        
                        #check all cameras in this batch
                        cornerRemovalList_allCams=list()
                        camerasInBatch = calibrator.views[batch_id].rerrs.keys()
                        for cidx in camerasInBatch:
                            
                            #calculate the reprojection errors statistics
                            corners, reprojs, rerrs = kcc.getReprojectionErrors(calibrator, cidx)        
                            me, se = kcc.getReprojectionErrorStatistics(rerrs)
                            se_threshold = 4.0*se #TODO: find good value 

                            #select corners to remove
                            cornerRemovalList=list()
                            for pidx, reproj in enumerate(rerrs[batch_id]):
                                if (not np.all(reproj==np.array([None,None]))) and (abs(reproj[0]) > se_threshold[0] or abs(reproj[1]) > se_threshold[1]):
                                    cornerRemovalList.append(pidx)
                                    
                                    #display the corners info
                                    if parsed.verbose or parsed.doPlotOutliers:
                                        sm.logInfo( "Outlier detected on view {4} with idx {5} (rerr=({0}, {1}) > ({2},{3}) )".format(reproj[0], reproj[1], se_threshold[0], se_threshold[1], view_id, pidx))
                                        sm.logInfo( "Predicted: {0}".format(calibrator.views[batch_id].rerrs[cidx][pidx].getPredictedMeasurement()) )
                                        sm.logInfo( "Measured: {0}".format(calibrator.views[batch_id].rerrs[cidx][pidx].getMeasurement()) )

                                    #store the outlier corners for plotting
                                    removedOutlierCorners.append( (cidx, calibrator.views[batch_id].rerrs[cidx][pidx].getMeasurement()) )
                            
                            #queue corners on this cam for removal
                            cornerRemovalList_allCams.append( (cidx, cornerRemovalList) )
                            
                            #plot the observation with the outliers
                            if len(cornerRemovalList)>0 and parsed.doPlotOutliers:                                
                                for cam_id, obs in calibrator.views[batch_id].rig_observations:
                                    if cam_id==cidx:
                                        gridobs = obs
                                fig=pl.figure(view_id*100+batch_id+cidx)                                
                                kcc.plotCornersAndReprojection(gridobs, reprojs[batch_id], cornerlist=cornerRemovalList, 
                                                               fno=fig.number, clearFigure=True, plotImage=True,
                                                               title="Removing outliers in view {0} on cam {0}".format(view_id, cidx))
                                pl.show()
    
                        #remove the corners (if there are corners to be removed)
                        removeCount = sum([len(removelist) for cidx, removelist in cornerRemovalList_allCams])
                        if removeCount>0:
                            original_batch = calibrator.views[batch_id]
                            new_batch = kcc.removeCornersFromBatch(original_batch, cornerRemovalList_allCams, useBlakeZissermanMest=parsed.doBlakeZisserman)
                            
                            #replace the original batch with the corrected
                            calibrator.estimator.removeBatch( calibrator.views[batch_id] )
                            calibrator.views[batch_id] = new_batch
                            rval = calibrator.estimator.addBatch( new_batch, False )
                            
                            #queue the batch for removal if the corrected batch was rejected
                            if not rval.batchAccepted:
                                sm.logDebug("corrected view rejected! removing from optimization...")
                                calibrator.views.remove( calibrator.views[batch_id] )
                            sm.logDebug("Removed {0} outlier corners on batch {1}".format(removeCount, batch_id))   
                        
                        #start and end filtering progress bar
                        if len(batches_to_check)>1:
                            progress_filter.sample()
                            
            #final output
            print
            print
            print ".................................................................."
            print
            print "Calibration complete."
            print
            if parsed.removeOutliers:
                sm.logWarn("Removed {0} outlier corners.".format(len(removedOutlierCorners)) )             
            print
            print "Processed {0} images with {1} images used".format(numViews, calibrator.estimator.getNumBatches())
            kcc.printParameters(calibrator)
            
            if parsed.verbose and len(calibrator.baselines)>1:
		        f=pl.figure(100006)
		        kcc.plotCameraRig(calibrator.baselines, fno=f.number, clearFigure=False)
		        pl.show()
            
            #write to file
            bagtag = parsed.bagfile.translate(None, "<>:/\|?*").replace('.bag', '', 1)
            resultFile = "camchain-" + bagtag + ".yaml"
            kcc.saveChainParametersYaml(calibrator, resultFile, graph)
            print "Results written to file: {0}".format(resultFile)
            
            #save results to file
            resultFileTxt = "results-cam-" + bagtag + ".txt"
            kcc.saveResultTxt(calibrator, filename=resultFileTxt)
            print "  Detailed results written to file: {0}".format(resultFileTxt)
            
            #generate report
            reportFile = "report-cam-" + bagtag + ".pdf"
            G=None; 
            if numCams>1: 
                G=graph
            kcc.generateReport(calibrator, reportFile, showOnScreen=not parsed.dontShowReport, graph=G, removedOutlierCorners=removedOutlierCorners);
            
        except kcc.OptimizationDiverged:
            restartAttempts-=1
            sm.logWarn("Optimization diverged possibly due to a bad initialization. (Do the models fit the lenses well?)")
            
            if restartAttempts==0:
                sm.logError("Max. attemps reached... Giving up...")
                break
            else:
                sm.logWarn("Restarting for a new attempt...")    
                
                #reinitialize the intrinsics
                for cam_id, cam in enumerate(cameraList):
                    print "Reinitialize the intrinsics for camera {0}".format(cam_id)
                    observations = obsdb.getAllObsCam(cam_id)
                    if not cam.initGeometryFromObservations(observations):
                        raise RuntimeError("Could not re-initialize the intrinsics for camera with topic: {0}".format(topic))
                    
                    print "\tProjection initialized to: %s" % cam.geometry.projection().getParameters().flatten()
                    print "\tDistortion initialized to: %s" % cam.geometry.projection().distortion().getParameters().flatten()
                
        else:
            break #normal exit

# Script entry point: run the calibration only when executed directly.
# Exceptions raised by main() are not caught here and propagate as-is.
if __name__ == "__main__":
    main()
#     try:
#         main()
#     except Exception,e:
#         sm.logError("Exception: {0}".format(e))
#         sys.exit(-1)
